import torch
import numpy as np
from util.landscape.scheduler import get_job_indices, get_unplotted_indices_and_plotted_losses, ModelParallelScheduler
import util.landscape.util as util
import h5py
import os
import torch.distributed as dist
import time
import sys
import math
import copy

def calculate_loss(model, input, target, total_loss, criterion):
    """Add one batch's sample-weighted loss to a running total.

    The model is switched to evaluation mode (and left in that mode), the
    forward pass runs with gradient tracking disabled, and the loss returned
    by ``criterion`` is multiplied by the batch size (``input.size(0)``)
    before being accumulated.

    Args:
        model: Callable module exposing ``eval()``; invoked as ``model(input)``.
        input: Batch tensor; its first dimension is used as the sample count.
        target: Targets passed to ``criterion`` together with the output.
        total_loss: Running sum to which this batch's contribution is added.
        criterion: Callable ``criterion(output, target)`` whose result
            supports ``.item()``.

    Returns:
        ``total_loss`` increased by ``loss.item() * input.size(0)``.
    """
    model.eval()
    with torch.no_grad():
        batch_loss = criterion(model(input), target)
        # ``.item()`` copies the scalar to the host, so this call blocks until
        # the forward pass queued on the current CUDA stream has finished.
        weighted = batch_loss.item() * input.size(0)
        total_loss += weighted
    return total_loss

def plot_surface(model, criterion, dataloader, args, init, directions, rank, local_rank):
    """Evaluate the loss at the still-missing grid points and store the results.

    The grid is ``np.linspace(args.xmin, args.xmax, args.xnum)`` by
    ``np.linspace(args.ymin, args.ymax, args.ynum)``. Points are processed in
    bundles; for every bundle one perturbed copy of ``model`` is built per
    non-NaN coordinate, each copy is evaluated over ``dataloader`` on its own
    CUDA stream, and the per-rank results are gathered on rank 0, which
    writes them to the HDF5 surface file under the key ``"{x}_{y}"``.

    Args:
        model: Base model; deep-copied for every evaluated coordinate.
        criterion: Loss callable forwarded to ``calculate_loss``.
        dataloader: Iterable of ``(input, target)`` batches.
        args: Configuration namespace. Attributes read here: ``animation``,
            ``path_to_surf_file``, ``id``, ``current_epoch`` (animation only),
            ``xmin/xmax/xnum``, ``ymin/ymax/ynum``, ``show_mem_per_gpu``
            (reset to ``False`` after the first report) and ``world_size``.
        init: Forwarded to ``util.set_weights``.
        directions: Forwarded to ``util.set_weights``.
        rank: Global rank; ``0`` or ``None`` is treated as the writer.
        local_rank: CUDA device index used by this process.

    Returns:
        Tuple ``(losses_old, losses_new)``: ``losses_old`` comes from
        ``get_unplotted_indices_and_plotted_losses``; ``losses_new`` is a
        float64 CUDA tensor of grid shape, initialised to ``-1`` and filled
        with the losses this rank computed.

    NOTE(review): ``dist.barrier``/``dist.gather`` are called unconditionally,
    so a process group must be initialised even when ``rank is None``.
    """
    is_root = rank == 0 or rank is None
    device = f"cuda:{local_rank}"

    if args.animation:
        surf_dir = os.path.join(args.path_to_surf_file, args.id)
        surf_file = os.path.join(surf_dir, "frame_" + str(args.current_epoch) + ".h5")
    else:
        surf_dir = None
        surf_file = os.path.join(args.path_to_surf_file, args.id + ".h5")

    # Only the root process creates the (empty) surface file.
    if is_root and not os.path.exists(surf_file):
        if surf_dir is not None:
            os.makedirs(surf_dir, exist_ok=True)
        with h5py.File(surf_file, 'w'):
            pass

    xcoordinates = np.linspace(args.xmin, args.xmax, num=args.xnum)
    ycoordinates = np.linspace(args.ymin, args.ymax, num=args.ynum)

    losses_new = -torch.ones((len(xcoordinates), len(ycoordinates)), device=device, dtype=torch.float64)
    dist.barrier()  # ensure every rank can read the surface file
    inds, coordinates, losses_old = get_unplotted_indices_and_plotted_losses(surf_file, xcoordinates, ycoordinates, rank, args)
    indices, coords, inds_nums, max_tasks_per_gpu, remainder = get_job_indices(args, inds, coordinates, rank)

    # Create the scheduler instance.
    scheduler = ModelParallelScheduler(local_rank, rank, args, max_tasks_per_gpu)
    scheduler.schedule(args)

    # Duration of the previous bundle; None until one bundle has completed so
    # the remaining-time estimate can never read an unassigned variable.
    evaluation_time = None
    for bundle_num, counts, idxs in util.batch_enumerate(indices, scheduler.streamcount):
        # Callable giving the number of valid streams for a given rank in this bundle.
        streams_len = util.get_stream_length_func(bundle_num, max_tasks_per_gpu, inds_nums, scheduler.streamcount)
        evaluation_start = time.time()
        coords_per_stream = [coords[i] for i in counts]
        # All-NaN coordinates get neither a model nor a stream.
        active_coords = [coord for coord in coords_per_stream if not np.isnan(coord).all()]
        models = [util.set_weights(copy.deepcopy(model), init, directions, coord) for coord in active_coords]
        streams = [torch.cuda.Stream() for _ in active_coords]
        total_losses = np.zeros(len(counts))
        # Slots whose index is NaN are marked with -1.
        # NOTE(review): models/streams are indexed 0..len(streams)-1 below, which
        # presumably relies on NaN entries sitting at the tail - confirm.
        total_losses[np.isnan(idxs)] = -1
        total_input = 0
        if args.show_mem_per_gpu and is_root:
            torch.cuda.reset_peak_memory_stats(device)
        for inputs, targets in dataloader:
            total_input += len(inputs)
            inputs, targets = inputs.cuda(local_rank), targets.cuda(local_rank)
            # The batch copy (and the weight setup above) were queued on this
            # device's current stream; side streams must be ordered after it.
            producer_stream = torch.cuda.current_stream(local_rank)
            for i, stream in enumerate(streams):
                stream.wait_stream(producer_stream)
                with torch.cuda.stream(stream):
                    total_losses[i] = calculate_loss(models[i], inputs, targets, total_losses[i], criterion)
                    if is_root and i % 100 == 0:
                        print(f"The stream {i+1} ends!")
        for stream in streams:
            stream.synchronize()
        if args.show_mem_per_gpu and is_root:
            memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
            print(f"Memory used per GPU for feedforward: {memory_used:.2f} GB")
            sys.stdout.flush()
            args.show_mem_per_gpu = False  # report only once
        losses = torch.tensor(total_losses / total_input, device=device)
        coords_per_stream = torch.tensor(np.array(coords_per_stream), device=device)
        gathered_losses = None
        gathered_coords_per_stream = None

        # Only the destination rank supplies receive buffers to gather().
        if is_root:
            gathered_losses = [torch.zeros_like(losses) for _ in range(args.world_size)]
            gathered_coords_per_stream = [torch.zeros_like(coords_per_stream) for _ in range(args.world_size)]
        dist.barrier()
        dist.gather(losses, gather_list=gathered_losses, dst=0)
        dist.gather(coords_per_stream, gather_list=gathered_coords_per_stream, dst=0)
        valid_idxs = idxs[~np.isnan(idxs)]  # indices that are not NaN
        losses_new.view(-1)[valid_idxs.astype(int)] = losses[:streams_len(rank)]

        if is_root:
            print("Gathering Finishes!")
            if counts[0] != 0 and inds_nums[0] != 0:
                print(f"At least {100*(counts[-1]+1)/max_tasks_per_gpu}% points have been evaluated.")
                if evaluation_time is not None:
                    print(f"The estimated remaining time is about {util.seconds2days_hours_minutes_seconds((inds_nums[0]-counts[-1]-1)/scheduler.streamcount*evaluation_time)}.")
            # The context manager closes the file even if a write fails.
            with h5py.File(surf_file, 'r+') as f:
                print("Begin writing!")
                for r in range(args.world_size):
                    for i in range(streams_len(r)):
                        key = f"{gathered_coords_per_stream[r][i][0]}_{gathered_coords_per_stream[r][i][1]}"
                        if key in f.keys():
                            # NOTE(review): the key already exists; the assignment
                            # below presumably fails in h5py - confirm intent.
                            print(key)
                        f[key] = gathered_losses[r][i].item()
                        f.flush()
            print("Writing finishes!")
        evaluation_time = time.time() - evaluation_start
    return losses_old, losses_new



# if bundle_num!=(math.ceil(len(indices)/scheduler.streamcount)-1):
#                 for r in range(args.world_size):
#                     for i in range(streams_length):
#                         if f"{gathered_coords_per_stream[r][i][0]}_{gathered_coords_per_stream[r][i][1]}" in f.keys():
#                             print(f"{gathered_coords_per_stream[r][i][0]}_{gathered_coords_per_stream[r][i][1]}")
#                         f[f"{gathered_coords_per_stream[r][i][0]}_{gathered_coords_per_stream[r][i][1]}"] = gathered_losses[r][i].item()
#                         f.flush()
#             else:  
#                 for i in range((len(indices)%scheduler.streamcount)-1):
#                     for r in range(args.world_size):
#                         if f"{gathered_coords_per_stream[r][i][0]}_{gathered_coords_per_stream[r][i][1]}" in f.keys():
#                             print(f"{gathered_coords_per_stream[r][i][0]}_{gathered_coords_per_stream[r][i][1]}")
#                         f[f"{gathered_coords_per_stream[r][i][0]}_{gathered_coords_per_stream[r][i][1]}"] = gathered_losses[r][i].item()
#                         f.flush()
#                 for r in range(remainder): 
#                     if f"{gathered_coords_per_stream[r][i][0]}_{gathered_coords_per_stream[r][i][1]}" in f.keys():
#                         print(f"{gathered_coords_per_stream[r][i][0]}_{gathered_coords_per_stream[r][i][1]}")
#                     f[f"{gathered_coords_per_stream[args.world_size-r-1][scheduler.streamcount-1][0]}_{gathered_coords_per_stream[args.world_size-r-1][scheduler.streamcount-1][1]}"] = gathered_losses[args.world_size-r-1][scheduler.streamcount-1].item()
#                     f.flush() 






                # for r in range(args.world_size):
                #     for i in range(scheduler.streamcount):    
                #         if gathered_padded_coords_per_stream[r][i].all():
                #             f[f"{gathered_padded_coords_per_stream[r][i][0]}_{gathered_padded_coords_per_stream[r][i][1]}"] = gathered_padded_losses[r][i].item()
                #             f.flush()

# if (buddle_num==(math.ceil(indices/scheduler.streamcount)-1)) and (i == (len(indices)%scheduler.streamcount)-1) and r < args.world_size-remainder : 
#                                 pass